from typing import Optional, Dict, Any

import pydantic


# Keys that must all be present for a dict to be treated as LangChain usage metadata.
LANGCHAIN_USAGE_KEYS = set(("input_tokens", "output_tokens", "total_tokens"))


class InputTokenDetails(pydantic.BaseModel):
    # Breakdown of prompt-side token usage. Every count defaults to None
    # (not reported); undeclared keys are kept because extra="allow".
    model_config = pydantic.ConfigDict(extra="allow")

    # Number of prompt tokens read from the cache.
    cache_read: Optional[int] = pydantic.Field(default=None)

    # Number of prompt tokens written into (created in) the cache.
    cache_creation: Optional[int] = pydantic.Field(default=None)

    # Number of tokens in the audio portion of the prompt.
    audio: Optional[int] = pydantic.Field(default=None)


class OutputTokenDetails(pydantic.BaseModel):
    # Breakdown of response-side token usage. Every count defaults to None
    # (not reported); undeclared keys are kept because extra="allow".
    model_config = pydantic.ConfigDict(extra="allow")

    # Number of tokens the model generated for reasoning.
    reasoning: Optional[int] = pydantic.Field(default=None)

    # Number of tokens in the audio portion of the response.
    audio: Optional[int] = pydantic.Field(default=None)


class LangChainUsage(pydantic.BaseModel):
    """
    LangChain usage metadata, modelled on (but not inherited from) the
    ``UsageMetadata`` TypedDict in ``langchain_core.messages.ai``.

    https://python.langchain.com/api_reference/core/messages/langchain_core.messages.ai.UsageMetadata.html#usagemetadata

    Unknown keys are retained (``extra="allow"``). The ``map_to_*`` methods
    translate this normalized usage into the key layout of a specific
    provider's native usage payload.
    """

    model_config = pydantic.ConfigDict(extra="allow")

    input_tokens: int
    """Number of tokens in the prompt."""

    output_tokens: int
    """Number of tokens in the response(s)."""

    total_tokens: int
    """Total token count for prompt, response candidates, and reasoning."""

    # Defaults make these genuinely optional: in pydantic v2 an ``Optional``
    # annotation without a default is still a *required* field.
    input_token_details: Optional[InputTokenDetails] = None
    """Breakdown of tokens used in the prompt."""

    output_token_details: Optional[OutputTokenDetails] = None
    """Breakdown of tokens used in a response."""

    @classmethod
    def from_original_usage_dict(cls, usage: Dict[str, Any]) -> "LangChainUsage":
        """
        Build an instance from a LangChain ``usage_metadata`` mapping.

        The ``input_token_details`` / ``output_token_details`` entries are
        accepted as plain dicts or as already-parsed detail models. Any other
        value, including a missing key or ``None``, becomes ``None``.

        Args:
            usage: Raw usage mapping. It is copied, never mutated.

        Returns:
            A validated ``LangChainUsage`` instance.

        Raises:
            pydantic.ValidationError: If required token counts are missing
                or invalid.
        """
        usage_dict = {**usage}
        input_token_details_raw = usage_dict.pop("input_token_details", None)
        output_token_details_raw = usage_dict.pop("output_token_details", None)

        # model_validate returns an existing instance as-is and parses dicts;
        # anything else is discarded rather than failing validation.
        input_token_details = (
            InputTokenDetails.model_validate(input_token_details_raw)
            if isinstance(input_token_details_raw, (dict, InputTokenDetails))
            else None
        )

        output_token_details = (
            OutputTokenDetails.model_validate(output_token_details_raw)
            if isinstance(output_token_details_raw, (dict, OutputTokenDetails))
            else None
        )

        return cls(
            **usage_dict,
            input_token_details=input_token_details,
            output_token_details=output_token_details,
        )

    def map_to_google_gemini_usage(self) -> Dict[str, Any]:
        """
        Map to Google Gemini usage-metadata field names.

        ``cached_content_token_count`` / ``thoughts_token_count`` are emitted
        whenever the corresponding detail model exists, even if its value is
        ``None``.
        """
        google_usage: Dict[str, Any] = {
            "prompt_token_count": self.input_tokens,
            "candidates_token_count": self.output_tokens,
            "total_token_count": self.total_tokens,
        }
        if self.input_token_details is not None:
            google_usage["cached_content_token_count"] = (
                self.input_token_details.cache_read
            )

        if self.output_token_details is not None:
            google_usage["thoughts_token_count"] = self.output_token_details.reasoning

        return google_usage

    def map_to_anthropic_usage(self) -> Dict[str, Any]:
        """
        Map to Anthropic usage field names.

        Cache fields are added only when input details exist and the
        individual count is not ``None``.
        """
        # NOTE(review): LangChain's UsageMetadata docs describe input_tokens as
        # the sum of all input token types, while Anthropic's native
        # input_tokens excludes cache reads/creations. The value is passed
        # through unchanged here; confirm downstream cost calculation does not
        # double-count the cache fields below.
        anthropic_usage: Dict[str, Any] = {
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "total_tokens": self.total_tokens,
        }
        if self.input_token_details is None:
            return anthropic_usage

        if self.input_token_details.cache_creation is not None:
            anthropic_usage["cache_creation_input_tokens"] = (
                self.input_token_details.cache_creation
            )
        if self.input_token_details.cache_read is not None:
            anthropic_usage["cache_read_input_tokens"] = (
                self.input_token_details.cache_read
            )

        return anthropic_usage

    def map_to_bedrock_usage(self) -> Dict[str, Any]:
        """
        Map to the AWS Bedrock Converse ``usage`` shape (camelCase keys).

        ``totalTokens`` is included, matching the Converse ``TokenUsage``
        structure and the other mappers here. Cache keys are emitted
        (possibly as ``None``) whenever input details exist.
        """
        bedrock_usage: Dict[str, Any] = {
            "inputTokens": self.input_tokens,
            "outputTokens": self.output_tokens,
            "totalTokens": self.total_tokens,
        }

        if self.input_token_details is not None:
            bedrock_usage["cacheReadInputTokens"] = self.input_token_details.cache_read
            bedrock_usage["cacheWriteInputTokens"] = (
                self.input_token_details.cache_creation
            )

        return bedrock_usage

    def map_to_openai_completions_usage(self) -> Dict[str, Any]:
        """
        Map to the OpenAI Chat Completions ``usage`` shape.

        ``prompt_tokens_details`` / ``completion_tokens_details`` sub-dicts
        are added whenever the corresponding detail model exists; their
        values may be ``None``.
        """
        openai_usage: Dict[str, Any] = {
            "prompt_tokens": self.input_tokens,
            "completion_tokens": self.output_tokens,
            "total_tokens": self.total_tokens,
        }

        if self.input_token_details is not None:
            openai_usage["prompt_tokens_details"] = {
                "cached_tokens": self.input_token_details.cache_read,
                "audio_tokens": self.input_token_details.audio,
            }

        if self.output_token_details is not None:
            openai_usage["completion_tokens_details"] = {
                "audio_tokens": self.output_token_details.audio,
                "reasoning_tokens": self.output_token_details.reasoning,
            }

        return openai_usage

    def map_to_groq_completions_usage(self) -> Dict[str, Any]:
        """Map to Groq's completions ``usage`` shape (the three top-level counts only)."""
        groq_usage: Dict[str, Any] = {
            "prompt_tokens": self.input_tokens,
            "completion_tokens": self.output_tokens,
            "total_tokens": self.total_tokens,
        }

        return groq_usage


def is_langchain_usage(usage_dict: Optional[Dict[str, Any]]) -> bool:
    """
    Return True if *usage_dict* looks like LangChain usage metadata.

    A value qualifies when it contains every key in ``LANGCHAIN_USAGE_KEYS``.
    ``None`` is accepted, hence the ``Optional`` annotation, and yields False.

    Args:
        usage_dict: Candidate usage mapping, or ``None``.

    Returns:
        True when all required keys are present, otherwise False.
    """
    if usage_dict is None:
        return False

    return all(key in usage_dict for key in LANGCHAIN_USAGE_KEYS)
